{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from transformer_copy import CopyTaskModel\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python's `random` (used for batch generation) and torch so that\n",
    "# Restart & Run All gives the same run every time.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Sequence length handed to generate_random_batch (covers <bos>, <eos> and padding).\n",
    "max_length = 16\n",
    "model = CopyTaskModel()\n",
    "# Token id 2 is <pad>. Skip padded targets and sum the per-token losses, so that the\n",
    "# training loop's division by n_tokens yields the mean loss over real tokens only.\n",
    "# (With the default reduction='mean' the loss would be averaged twice and would also\n",
    "# count the <pad> positions.)\n",
    "criteria = nn.CrossEntropyLoss(ignore_index=2, reduction='sum')\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_batch(batch_size, max_length=16):\n",
    "    \"\"\"Build one random batch for the copy task.\n",
    "\n",
    "    Token ids: 0 = <bos>, 1 = <eos>, 2 = <pad>, 3..9 = ordinary vocabulary.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of sentences to generate.\n",
    "        max_length: length of every row of `src`, counting <bos>, <eos> and padding.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (src, tgt, tgt_y, n_tokens):\n",
    "        src      - LongTensor of shape (batch_size, max_length).\n",
    "        tgt      - src without its last column (decoder input).\n",
    "        tgt_y    - src without its first column (prediction target).\n",
    "        n_tokens - 0-dim tensor, number of entries of tgt_y that are not <pad>;\n",
    "                   used later to normalise the loss.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for _ in range(batch_size):\n",
    "        # Draw the number of content tokens first, then the tokens themselves.\n",
    "        body_len = random.randint(1, max_length - 2)\n",
    "        body = [random.randint(3, 9) for _ in range(body_len)]\n",
    "        # <bos> + content + <eos>, right-padded with <pad> up to max_length.\n",
    "        padding = [2] * (max_length - 2 - body_len)\n",
    "        rows.append([0, *body, 1, *padding])\n",
    "\n",
    "    src = torch.LongTensor(rows)\n",
    "    # Shift by one position: tgt drops the last token, tgt_y drops the first.\n",
    "    tgt, tgt_y = src[:, :-1], src[:, 1:]\n",
    "    # Count the target tokens that actually have to be predicted.\n",
    "    n_tokens = tgt_y.ne(2).sum()\n",
    "    return src, tgt, tgt_y, n_tokens\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of optimisation steps between two loss reports.\n",
    "LOG_EVERY = 40\n",
    "# Kept as a plain float: accumulating the loss tensor itself would keep autograd\n",
    "# history of every step alive until the next reset.\n",
    "total_loss = 0.0\n",
    "\n",
    "# Put the model back in training mode in case this cell is re-run after the\n",
    "# evaluation cell has called model.eval().\n",
    "model.train()\n",
    "\n",
    "for step in range(2000):\n",
    "    # Fresh random batch for every step.\n",
    "    src, tgt, tgt_y, n_tokens = generate_random_batch(batch_size=2, max_length=max_length)\n",
    "\n",
    "    # Clear gradients left over from the previous step.\n",
    "    optimizer.zero_grad()\n",
    "    # Run the transformer, then feed its output to the model's `predictor` head\n",
    "    # (the final linear layer, per the original notes).\n",
    "    out = model.predictor(model(src, tgt))\n",
    "    # Every output position is predicted during training, so `out`\n",
    "    # (batch_size, seq_len, vocab_size per the original notes) is flattened to\n",
    "    # (batch_size * seq_len, vocab_size) and tgt_y to (batch_size * seq_len,).\n",
    "    # The value returned by `criteria` is then divided by n_tokens, the number of\n",
    "    # non-<pad> target tokens.\n",
    "    loss = criteria(out.reshape(-1, out.size(-1)), tgt_y.reshape(-1)) / n_tokens\n",
    "    # Back-propagate and update the parameters.\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    total_loss += loss.item()\n",
    "\n",
    "    # Report after every LOG_EVERY-th step so that each printed total covers\n",
    "    # exactly LOG_EVERY steps, then start a new window.\n",
    "    if (step + 1) % LOG_EVERY == 0:\n",
    "        print(\"Step {}, total_loss: {}\".format(step + 1, total_loss))\n",
    "        total_loss = 0.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the model to evaluation mode before decoding.\n",
    "model = model.eval()\n",
    "# An arbitrary source sentence: <bos>, eight vocabulary tokens, <eos>, two <pad>.\n",
    "src = torch.LongTensor([[0] + [4, 3, 4, 6, 8, 9, 9, 8] + [1] + [2, 2]])\n",
    "# Decoding starts from a lone <bos>; the aim is to see whether the model can\n",
    "# reproduce the tokens of src.\n",
    "tgt = torch.LongTensor([[0]])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy decoding of a single sentence: predict one token at a time until <eos>\n",
    "# is produced or max_length tokens have been generated.\n",
    "\n",
    "# Restart from the first token (<bos>) so that re-running this cell decodes from\n",
    "# scratch instead of continuing from the previous run's output.\n",
    "tgt = tgt[:, :1]\n",
    "\n",
    "# Inference only: no gradients are needed, so skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    for _ in range(max_length):\n",
    "        # Run the transformer on everything generated so far.\n",
    "        out = model(src, tgt)\n",
    "        # Only the last position is needed to choose the next token.\n",
    "        predict = model.predictor(out[:, -1])\n",
    "        # Index of the highest-scoring vocabulary entry, shape (batch,).\n",
    "        y = torch.argmax(predict, dim=1)\n",
    "        # Add a sequence axis and append the new token as one more column.\n",
    "        tgt = torch.cat([tgt, y.unsqueeze(1)], dim=1)\n",
    "\n",
    "        # <eos> (id 1) means the sentence is finished.\n",
    "        if y.item() == 1:\n",
    "            break\n",
    "print(tgt)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trans",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
